package babysitter

import (
	"context"
	"crypto/sha256"
	_ "embed"
	"errors"
	"fmt"
	"math"
	"net"
	"net/http"
	"path"
	"sort"
	"sync"
	"syscall"
	"time"

	"github.com/google/uuid"
	"go.opentelemetry.io/otel/sdk/trace"
	"golang.org/x/exp/maps"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"
	"greatestworks/aop/envelope"
	"greatestworks/aop/logging"
	"greatestworks/aop/logtype"
	"greatestworks/aop/metrics"
	"greatestworks/aop/perfetto"
	"greatestworks/aop/protomsg"
	"greatestworks/aop/protos"
	"greatestworks/aop/proxy"
	"greatestworks/aop/retry"
	"greatestworks/aop/status"
	"greatestworks/aop/versioned_map"
)

const (
	// DefaultReplication is the default replication factor for a component.
	// startColocationGroup starts this many weavelet replicas for each
	// colocation group.
	DefaultReplication = 2

	// routingInfoKey is the key prefix under which we track routing
	// information for a given colocation group; routingKey joins it with the
	// group name to form the full key.
	routingInfoKey = "routing_entries"

	// appVersionStateKey is the key where we track the state for a given
	// application version.
	appVersionStateKey = "app_version_state"
)

// Babysitter manages an application version deployment.
//
// It starts the weavelet replicas of each colocation group, tracks the
// application and routing state in versioned maps, runs a proxy for every
// exported listener, and serves as the envelope.EnvelopeHandler for the
// envelopes it starts.
type Babysitter struct {
	// ctx is the context handed to NewBabysitter. It is used for the
	// babysitter's background work (metric collection, running envelopes,
	// serving proxies) and for versioned map reads.
	ctx context.Context
	// opts are the options passed to every envelope the babysitter creates.
	opts envelope.Options
	// dep describes the deployment being managed.
	dep *protos.Deployment
	// logger is the babysitter's own logger; its entries are written through
	// the log saver supplied to NewBabysitter.
	logger logtype.Logger

	// logSaver processes log entries generated by the weavelet. The entries
	// either have the timestamp produced by the weavelet, or have a nil Time
	// field. Defaults to a log saver that pretty prints log entries to stderr.
	//
	// logSaver is called concurrently from multiple goroutines, so it should
	// be thread safe.
	logSaver func(*protos.LogEntry)

	// traceSaver processes trace spans generated by the weavelet. If nil,
	// weavelet traces are dropped.
	//
	// traceSaver is called concurrently from multiple goroutines, so it should
	// be thread safe.
	traceSaver func([]trace.ReadOnlySpan) error

	// statsProcessor tracks and computes stats to be rendered on the /statusz page.
	statsProcessor *metrics.StatsProcessor

	// mu guards managed and proxies. The methods in this file also hold mu
	// while they call Update on appState and routingState; reads of those two
	// maps are performed without mu.
	mu           sync.RWMutex
	managed      map[string][]*envelope.Envelope // replica envelopes, by group
	appState     *versioned_map.Map[*AppVersionState] // app state, stored under appVersionStateKey
	routingState *versioned_map.Map[*protos.RoutingInfo] // routing info, keyed by routingKey(group)
	proxies      map[string]*proxyInfo // proxies, by listener name
}

// proxyInfo pairs a running listener proxy with the address it listens on.
type proxyInfo struct {
	proxy *proxy.Proxy // the proxy that forwards traffic to the listener's backends
	addr  string // dialable address of the proxy
}

// Compile-time check that *Babysitter implements envelope.EnvelopeHandler.
var _ envelope.EnvelopeHandler = &Babysitter{}

// NewBabysitter creates a new babysitter for the deployment dep.
//
// logSaver receives the log entries of the babysitter itself and of the
// weavelets it manages. Trace spans are stored in the Perfetto database, and
// an error is returned if that database cannot be opened. A background
// goroutine that collects metrics is started before returning; it is given
// ctx.
func NewBabysitter(ctx context.Context, dep *protos.Deployment, logSaver func(*protos.LogEntry)) (*Babysitter, error) {
	// Options identifying the babysitter's own log entries.
	loggerOpts := logging.Options{
		App:       dep.App.Name,
		Component: "babysitter",
		Weavelet:  uuid.NewString(),
		Attrs:     []string{"serviceweaver/system", ""},
	}

	// Open the database that backs the trace saver.
	traceDB, err := perfetto.Open(ctx)
	if err != nil {
		return nil, fmt.Errorf("cannot open Perfetto database: %w", err)
	}

	b := &Babysitter{
		ctx:      ctx,
		opts:     envelope.Options{Restart: envelope.Never, Retry: retry.DefaultOptions},
		dep:      dep,
		logger:   logging.FuncLogger{Opts: loggerOpts, Write: logSaver},
		logSaver: logSaver,
		traceSaver: func(spans []trace.ReadOnlySpan) error {
			return traceDB.Store(ctx, dep.App.Name, dep.Id, spans)
		},
		statsProcessor: metrics.NewStatsProcessor(),
		managed:        map[string][]*envelope.Envelope{},
		appState:       versioned_map.NewMap[*AppVersionState](),
		routingState:   versioned_map.NewMap[*protos.RoutingInfo](),
		proxies:        map[string]*proxyInfo{},
	}

	// Periodically collect metrics from the babysitter and its envelopes.
	go b.statsProcessor.CollectMetrics(b.ctx, b.readMetrics)
	return b, nil
}

// RegisterStatusPages registers the status pages with the provided mux. The
// babysitter itself acts as the status server backing those pages, and its
// logger is handed to the status package.
func (b *Babysitter) RegisterStatusPages(mux *http.ServeMux) {
	status.RegisterServer(mux, b, b.logger)
}

// startColocationGroup makes sure that DefaultReplication replicas of the
// given colocation group have been started, starting only the missing ones.
//
// If a previous call failed part way through, the replicas it did start are
// kept and this call resumes from the next replica index, so a group is never
// over-provisioned and replica group ids are never reused.
//
// It returns an error if an envelope cannot be created. Errors from running
// an envelope are only logged.
//
// REQUIRES: b.mu is held.
func (b *Babysitter) startColocationGroup(group *protos.ColocationGroup) error {
	for r := len(b.managed[group.Name]); r < DefaultReplication; r++ {
		// Note that we assign a distinct UUID to each group replica, derived
		// from the replica index. This is because we use the group replica
		// ids to create replica-local addresses to communicate between the
		// weavelets.
		id := uuid.NewHash(sha256.New(), uuid.Nil, []byte(fmt.Sprintf("%d", r)), 0).String()

		// Start the weavelet and capture its logs, traces, and metrics.
		wlet := &protos.WeaveletInfo{
			App:           b.dep.App.Name,
			DeploymentId:  b.dep.Id,
			Group:         group,
			GroupId:       id,
			Id:            uuid.New().String(),
			SameProcess:   b.dep.App.SameProcess,
			Sections:      b.dep.App.Sections,
			SingleProcess: b.dep.SingleProcess,
			SingleMachine: true,
		}
		e, err := envelope.NewEnvelope(wlet, b.dep.App, b, b.opts)
		if err != nil {
			return fmt.Errorf("creating envelope for replica %d of group %q: %w", r, group.Name, err)
		}
		go func() {
			// TODO(mwhittaker): Propagate errors.
			if err := e.Run(b.ctx); err != nil {
				b.logger.Error("e.Run", err)
			}
		}()
		b.managed[group.Name] = append(b.managed[group.Name], e)
	}
	return nil
}

// StartComponent implements the envelope.EnvelopeHandler interface.
//
// It records the component in the state of its colocation group, creates an
// initial routing assignment for the component if it is routed and has none,
// regenerates the group's routing info, stores the updated app state, and
// finally starts the colocation group if it is not running yet.
func (b *Babysitter) StartComponent(req *protos.ComponentToStart) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Fetch the latest app state and the state of the component's group.
	appState, _, err := b.loadAppState("" /*version*/)
	if err != nil {
		return err
	}
	groupState := b.findOrAddGroup(appState, req.ColocationGroup)

	// Remember the component together with its routed flag.
	groupState.Components[req.Component] = req.IsRouted

	// A routed component needs an assignment; keep an existing one untouched.
	if _, assigned := groupState.Assignments[req.Component]; req.IsRouted && !assigned {
		groupState.Assignments[req.Component] = &protos.Assignment{
			App:          b.dep.App.Name,
			DeploymentId: b.dep.Id,
			Component:    req.Component,
		}
	}
	if err = b.mayGenerateNewRoutingInfo(groupState); err != nil {
		return err
	}

	// Persist the updated app state.
	b.appState.Update(appVersionStateKey, appState)

	// Bring up the colocation group, unless it is already running.
	group := &protos.ColocationGroup{Name: req.ColocationGroup}
	return b.startColocationGroup(group)
}

// RegisterReplica implements the envelope.EnvelopeHandler interface.
//
// It adds the replica's address and pid to the state of its colocation group
// (unless the address is already registered), regenerates the group's routing
// info, and stores the updated app state.
func (b *Babysitter) RegisterReplica(req *protos.ReplicaToRegister) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Fetch the latest app state and the state of the replica's group.
	appState, _, err := b.loadAppState("" /*version*/)
	if err != nil {
		return err
	}
	groupState := b.findOrAddGroup(appState, req.Group)

	// Scan the registered addresses; stop at the first match.
	known := false
	for i := 0; i < len(groupState.Replicas) && !known; i++ {
		known = groupState.Replicas[i] == req.Address
	}
	if !known {
		groupState.Replicas = append(groupState.Replicas, req.Address)
		groupState.ReplicaPids = append(groupState.ReplicaPids, req.Pid)
	}

	// The replica set may have changed, so regenerate the routing info.
	if err = b.mayGenerateNewRoutingInfo(groupState); err != nil {
		return err
	}

	// Persist the updated app state.
	b.appState.Update(appVersionStateKey, appState)
	return nil
}

// GetComponentsToStart implements the envelope.EnvelopeHandler interface.
//
// It reads the app state relative to req.Version and replies with the names
// of the components recorded for req.Group, together with the version of the
// state that was read. The component names are in no particular order.
func (b *Babysitter) GetComponentsToStart(req *protos.GetComponentsToStart) (*protos.ComponentsToStart, error) {
	appState, version, err := b.loadAppState(req.Version)
	if err != nil {
		return nil, err
	}
	groupState := b.findOrAddGroup(appState, req.Group)

	return &protos.ComponentsToStart{
		Version:    version,
		Components: maps.Keys(groupState.Components),
	}, nil
}

// RecvLogEntry implements the envelope.EnvelopeHandler interface. It forwards
// the entry to the log saver supplied to NewBabysitter, which is called
// unconditionally and therefore must be non-nil.
func (b *Babysitter) RecvLogEntry(entry *protos.LogEntry) {
	b.logSaver(entry)
}

// RecvTraceSpans implements the envelope.EnvelopeHandler interface. The spans
// are handed to the trace saver; when no trace saver is configured they are
// dropped and nil is returned.
func (b *Babysitter) RecvTraceSpans(spans []trace.ReadOnlySpan) error {
	if save := b.traceSaver; save != nil {
		return save(spans)
	}
	return nil
}

// ReportLoad implements the envelope.EnvelopeHandler interface. Load reports
// are ignored by this babysitter; the method always returns nil.
func (b *Babysitter) ReportLoad(*protos.WeaveletLoadReport) error {
	return nil
}

// GetAddress implements the envelope.EnvelopeHandler interface. The request
// is ignored and the reply always carries the address "localhost:0".
// NOTE(review): port 0 presumably lets the weavelet listen on any free local
// port; confirm against the weavelet side.
func (b *Babysitter) GetAddress(req *protos.GetAddressRequest) (*protos.GetAddressReply, error) {
	return &protos.GetAddressReply{Address: "localhost:0"}, nil
}

// ExportListener implements the envelope.EnvelopeHandler interface.
//
// The listener is added as a backend of the proxy registered under its name.
// If no such proxy exists yet, one is created: it listens on
// req.LocalAddress and is served in a background goroutine until b.ctx is
// cancelled. The reply carries the address of the proxy.
//
// If req.LocalAddress is already in use, the failure is reported in the
// reply's Error field rather than as an error, so that the caller does not
// retry. Any other listen failure is returned as an error.
//
// The listener is recorded in the app state only once it has been attached to
// a proxy, so failed (and possibly retried) calls do not leave listeners
// behind in the state.
func (b *Babysitter) ExportListener(req *protos.ExportListenerRequest) (*protos.ExportListenerReply, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Load app state.
	state, _, err := b.loadAppState("" /*version*/)
	if err != nil {
		return nil, err
	}

	// recordListener updates and stores the state. It is only called after
	// the listener has been attached to a proxy.
	recordListener := func() {
		state.Listeners = append(state.Listeners, req.Listener)
		b.appState.Update(appVersionStateKey, state)
	}

	// Reuse the proxy if one already exists for this listener name.
	if p, ok := b.proxies[req.Listener.Name]; ok {
		p.proxy.AddBackend(req.Listener.Addr)
		recordListener()
		return &protos.ExportListenerReply{ProxyAddress: p.addr}, nil
	}

	// Otherwise, create a new proxy.
	lis, err := net.Listen("tcp", req.LocalAddress)
	if errors.Is(err, syscall.EADDRINUSE) {
		// Don't retry if this address is already in use.
		return &protos.ExportListenerReply{Error: err.Error()}, nil
	}
	if err != nil {
		return nil, fmt.Errorf("proxy listen: %w", err)
	}
	addr := lis.Addr().String()
	b.logger.Info("Proxy listening", "address", addr)
	px := proxy.NewProxy(b.logger)
	px.AddBackend(req.Listener.Addr)
	b.proxies[req.Listener.Name] = &proxyInfo{proxy: px, addr: addr}
	recordListener()
	go func() {
		if err := serveHTTP(b.ctx, lis, px); err != nil {
			b.logger.Error("proxy", err)
		}
	}()
	return &protos.ExportListenerReply{ProxyAddress: addr}, nil
}

// GetRoutingInfo implements the envelope.EnvelopeHandler interface. It reads
// the routing info of req.Group relative to req.Version and returns it with
// its Version field set to the version that was read.
func (b *Babysitter) GetRoutingInfo(req *protos.GetRoutingInfo) (*protos.RoutingInfo, error) {
	info, version, err := b.loadRoutingState(req.Group, req.Version)
	if err != nil {
		return nil, err
	}

	// Stamp the reply with the version it corresponds to.
	info.Version = version
	return info, nil
}

// loadAppState reads the app version state from the versioned map, relative
// to the given version, and returns it along with the version that was read.
//
// The returned state is never nil and always has a non-nil Groups map: if
// nothing has been stored yet, a fresh state describing the deployment is
// returned instead.
func (b *Babysitter) loadAppState(version string) (*AppVersionState, string, error) {
	appState, readVersion, err := b.appState.Read(b.ctx, appVersionStateKey, version)
	switch {
	case err != nil:
		return nil, "", err
	case appState == nil:
		// Nothing stored so far; start from an empty state.
		appState = &AppVersionState{
			App:            b.dep.App.Name,
			DeploymentId:   b.dep.Id,
			SubmissionTime: timestamppb.Now(),
		}
	}

	// TODO(spetrovic): Versioned map stores empty maps as nil maps. For that
	// reason the map has to be (re)initialized on every load, and not only
	// when a brand new AppVersionState is created above.
	if appState.Groups == nil {
		appState.Groups = map[string]*ColocationGroupState{}
	}
	return appState, readVersion, nil
}

// findOrAddGroup returns the state of the named colocation group, adding an
// empty entry to state.Groups if the group is missing. The returned group
// state always has non-nil Components and Assignments maps.
func (b *Babysitter) findOrAddGroup(state *AppVersionState, group string) *ColocationGroupState {
	if existing := state.Groups[group]; existing == nil {
		state.Groups[group] = &ColocationGroupState{Name: group}
	}
	groupState := state.Groups[group]

	// TODO(spetrovic): Versioned map stores empty maps as nil maps. For that
	// reason the maps have to be (re)initialized on every lookup, and not
	// only when a brand new ColocationGroupState is created above.
	if groupState.Components == nil {
		groupState.Components = map[string]bool{}
	}
	if groupState.Assignments == nil {
		groupState.Assignments = map[string]*protos.Assignment{}
	}
	return groupState
}

// getEnvelopes returns the envelopes of all managed colocation groups as one
// flat slice, in no particular order. It takes b.mu for reading.
func (b *Babysitter) getEnvelopes() []*envelope.Envelope {
	b.mu.RLock()
	defer b.mu.RUnlock()

	var all []*envelope.Envelope
	for group := range b.managed {
		all = append(all, b.managed[group]...)
	}
	return all
}

// getManagedProcesses returns a snapshot of the managed envelopes, by group.
// The map and its slices are copies, so the caller may use them without
// holding b.mu. It takes b.mu for reading.
func (b *Babysitter) getManagedProcesses() map[string][]*envelope.Envelope {
	b.mu.RLock()
	defer b.mu.RUnlock()

	snapshot := map[string][]*envelope.Envelope{}
	for group, envs := range b.managed {
		snapshot[group] = append([]*envelope.Envelope(nil), envs...)
	}
	return snapshot
}

// readMetrics returns the metric snapshots of every managed envelope followed
// by the snapshots of the babysitter's own process. Envelopes whose metrics
// cannot be read are skipped.
func (b *Babysitter) readMetrics() []*metrics.MetricSnapshot {
	var snapshots []*metrics.MetricSnapshot
	for _, env := range b.getEnvelopes() {
		if fromEnvelope, err := env.ReadMetrics(); err == nil {
			snapshots = append(snapshots, fromEnvelope...)
		}
	}

	// Add the metrics of the babysitter itself.
	snapshots = append(snapshots, metrics.Snapshot()...)
	return snapshots
}

// Profile implements the status.Server interface. It profiles the managed
// processes and labels the resulting profile with the application name and
// the deployment id. Note that profiling runs under the babysitter's own
// context; the context argument is not used.
func (b *Babysitter) Profile(_ context.Context, req *protos.RunProfiling) (*protos.Profile, error) {
	processes := b.getManagedProcesses()
	prof, err := runProfiling(b.ctx, req, processes)
	if err != nil {
		return nil, err
	}
	prof.AppName, prof.VersionId = b.dep.App.Name, b.dep.Id
	return prof, nil
}

// Status implements the status.Server interface.
//
// The returned status lists every component of every colocation group,
// annotated with the pids of the group's replicas and with the per-method
// stats computed by the stats processor, as well as every exported listener
// together with the address of its proxy.
//
// Components are sorted by group and then by name, and listeners by name, so
// that repeated calls report them in a stable order. The context argument is
// not used.
func (b *Babysitter) Status(ctx context.Context) (*status.Status, error) {
	state, _, err := b.loadAppState("" /*version*/)
	if err != nil {
		return nil, err
	}

	stats := b.statsProcessor.GetStatsStatusz()
	var components []*status.Component
	for _, g := range state.Groups {
		for component := range g.Components {
			c := &status.Component{
				Name:  component,
				Group: g.Name,
				Pids:  g.ReplicaPids,
			}
			components = append(components, c)

			// TODO(mwhittaker): Unify with ui package and remove duplication.
			// Stats are keyed by the shortened component name.
			s := stats[logging.ShortenComponent(component)]
			if s == nil {
				continue
			}
			for _, methodStats := range s {
				c.Methods = append(c.Methods, &status.Method{
					Name: methodStats.Name,
					Minute: &status.MethodStats{
						NumCalls:     methodStats.Minute.NumCalls,
						AvgLatencyMs: methodStats.Minute.AvgLatencyMs,
						RecvKbPerSec: methodStats.Minute.RecvKBPerSec,
						SentKbPerSec: methodStats.Minute.SentKBPerSec,
					},
					Hour: &status.MethodStats{
						NumCalls:     methodStats.Hour.NumCalls,
						AvgLatencyMs: methodStats.Hour.AvgLatencyMs,
						RecvKbPerSec: methodStats.Hour.RecvKBPerSec,
						SentKbPerSec: methodStats.Hour.SentKBPerSec,
					},
					Total: &status.MethodStats{
						NumCalls:     methodStats.Total.NumCalls,
						AvgLatencyMs: methodStats.Total.AvgLatencyMs,
						RecvKbPerSec: methodStats.Total.RecvKBPerSec,
						SentKbPerSec: methodStats.Total.SentKBPerSec,
					},
				})
			}
		}
	}
	// Map iteration order is random; sort for a stable report.
	sort.Slice(components, func(i, j int) bool {
		if components[i].Group != components[j].Group {
			return components[i].Group < components[j].Group
		}
		return components[i].Name < components[j].Name
	})

	// The proxies are only read here, so a read lock is sufficient.
	b.mu.RLock()
	var listeners []*status.Listener
	for name, p := range b.proxies {
		listeners = append(listeners, &status.Listener{
			Name: name,
			Addr: p.addr,
		})
	}
	b.mu.RUnlock()
	sort.Slice(listeners, func(i, j int) bool {
		return listeners[i].Name < listeners[j].Name
	})

	return &status.Status{
		App:            state.App,
		DeploymentId:   state.DeploymentId,
		SubmissionTime: state.SubmissionTime,
		Components:     components,
		Listeners:      listeners,
		Config:         b.dep.App,
	}, nil
}

// Metrics implements the status.Server interface. It converts the current
// metric snapshots of the babysitter and its envelopes to their proto form.
// The context argument is not used.
func (b *Babysitter) Metrics(ctx context.Context) (*status.Metrics, error) {
	result := new(status.Metrics)
	snapshots := b.readMetrics()
	for i := range snapshots {
		result.Metrics = append(result.Metrics, snapshots[i].ToProto())
	}
	return result, nil
}

// mayGenerateNewRoutingInfo may generate new routing information for a given
// colocation group.
//
// This method is called whenever (1) the colocation group starts managing
// new routed components, or (2) a new replica of the colocation group gets
// started.
//
// Every assignment of the group is recomputed over the group's current
// replicas. If an assignment cannot be recomputed, the error is logged and
// the previous assignment is kept. The resulting routing info lists the
// replicas in sorted order and the assignments in order of component name,
// so that equal inputs produce equal routing info.
//
// REQUIRES: b.mu is held.
func (b *Babysitter) mayGenerateNewRoutingInfo(g *ColocationGroupState) error {
	// Sort a copy of the replicas rather than g.Replicas itself: RegisterReplica
	// appends to g.Replicas and g.ReplicaPids in lockstep, and sorting only one
	// of the two in place would break their index correspondence. The copy also
	// keeps the published routing info from sharing memory with the group state.
	replicas := append([]string(nil), g.Replicas...)
	sort.Strings(replicas)

	// Visit the components in a deterministic order; map iteration is random.
	components := make([]string, 0, len(g.Assignments))
	for component := range g.Assignments {
		components = append(components, component)
	}
	sort.Strings(components)

	info := &protos.RoutingInfo{Replicas: replicas}
	for _, component := range components {
		newAssignment, err := routingAlgo(g.Assignments[component], replicas)
		switch {
		case err != nil:
			// Keep the current assignment, but don't hide the failure.
			b.logger.Error("routingAlgo", err)
		case newAssignment != nil:
			g.Assignments[component] = newAssignment
		}
		info.Assignments = append(info.Assignments, g.Assignments[component])
	}

	// Update the routing information.
	return b.updateRoutingInfo(g, info)
}

// updateRoutingInfo updates the state with the latest routing info for a
// colocation group. The stored routing info is left untouched when it is
// already equal to info.
//
// REQUIRES: b.mu is held.
func (b *Babysitter) updateRoutingInfo(g *ColocationGroupState, info *protos.RoutingInfo) error {
	current, _, err := b.loadRoutingState(g.Name, "" /*version*/)
	if err != nil {
		return err
	}
	if !proto.Equal(current, info) {
		// Something changed; publish the new routing info.
		b.routingState.Update(routingKey(g.Name), info)
	}
	return nil
}

// loadRoutingState reads the routing info of the given colocation group from
// the versioned map, relative to the given version, and returns it along with
// the version that was read. The returned routing info is never nil: if
// nothing has been stored for the group, an empty one is returned.
func (b *Babysitter) loadRoutingState(group, version string) (*protos.RoutingInfo, string, error) {
	info, readVersion, err := b.routingState.Read(b.ctx, routingKey(group), version)
	switch {
	case err != nil:
		return nil, "", err
	case info == nil:
		info = &protos.RoutingInfo{}
	}
	return info, readVersion, nil
}

// routingAlgo is an implementation of a routing algorithm that distributes the
// entire key space approximately equally across all healthy resources.
//
// The algorithm is as follows:
//   - split the entire key space into a number of slices (the next power of
//     two that is at least the number of resources), which makes it more
//     likely that the key space is spread uniformly among all healthy
//     resources
//   - distribute the slices round robin across all healthy resources
//
// It returns a clone of currAssignment with an incremented version and the
// newly computed slices; currAssignment and candidates are not modified. With
// no candidates the new assignment has no slices, and with a single candidate
// it has one slice covering the whole key space. The returned error is
// currently always nil.
func routingAlgo(currAssignment *protos.Assignment, candidates []string) (*protos.Assignment, error) {
	newAssignment := protomsg.Clone(currAssignment)
	newAssignment.Version++

	// Note that the healthy resources should be sorted. This is required because
	// we want to do a deterministic assignment of slices to resources among
	// different invocations, to avoid unnecessary churn while generating
	// new assignments. A copy is sorted so that the caller's slice is neither
	// reordered nor shared with the returned assignment.
	replicas := append([]string(nil), candidates...)
	sort.Strings(replicas)

	const minSliceKey = 0
	const maxSliceKey = math.MaxUint64

	switch len(replicas) {
	case 0:
		newAssignment.Slices = nil
		return newAssignment, nil
	case 1:
		// With only one healthy resource, assign the entire key space to it.
		newAssignment.Slices = []*protos.Assignment_Slice{
			{Start: minSliceKey, Replicas: replicas},
		}
		return newAssignment, nil
	}

	// Compute the total number of slices in the assignment. There are at
	// least two replicas here, so numSlices is at least two.
	numSlices := nextPowerOfTwo(len(replicas))

	// Repeatedly take the oldest slice off the front of the queue and replace
	// it with its two halves, until there are numSlices slices.
	splits := [][]uint64{{minSliceKey, maxSliceKey}}
	for len(splits) < numSlices {
		curr := splits[0]
		splits = splits[1:]
		// The midpoint is computed in float64, which cannot represent every
		// uint64 exactly, so split points are approximate.
		midPoint := curr[0] + uint64(math.Floor(0.5*float64(curr[1]-curr[0])))
		splits = append(splits, []uint64{curr[0], midPoint}, []uint64{midPoint, curr[1]})
	}

	// Sort the computed slices in increasing order based on the start key, in
	// order to provide a deterministic assignment across multiple runs, hence to
	// minimize churn. The comparison must be strict ("<"): sort.Slice requires
	// less(i, i) to be false.
	sort.Slice(splits, func(i, j int) bool {
		return splits[i][0] < splits[j][0]
	})

	// Assign the computed slices to resources in a round robin fashion.
	slices := make([]*protos.Assignment_Slice, len(splits))
	for i, s := range splits {
		slices[i] = &protos.Assignment_Slice{
			Start:    s[0],
			Replicas: []string{replicas[i%len(replicas)]},
		}
	}
	newAssignment.Slices = slices
	return newAssignment, nil
}

// serveHTTP serves HTTP traffic on the provided listener using the provided
// handler. The server is shut down when the provided context is cancelled.
//
// If serving fails, the error from Serve is returned. When ctx is cancelled,
// the server stops accepting connections and in-flight requests get up to
// shutdownTimeout to complete; serveHTTP returns nil if they all complete in
// time. Otherwise the remaining connections are closed forcibly and the
// shutdown error is returned.
func serveHTTP(ctx context.Context, lis net.Listener, handler http.Handler) error {
	// shutdownTimeout bounds how long a graceful shutdown may take.
	const shutdownTimeout = 5 * time.Second

	server := http.Server{Handler: handler}
	// The channel is buffered so the serving goroutine can always deliver its
	// result and exit, even after we have stopped listening for it.
	errs := make(chan error, 1)
	go func() { errs <- server.Serve(lis) }()
	select {
	case err := <-errs:
		return err
	case <-ctx.Done():
	}

	// ctx is already cancelled at this point, so it must not be used for the
	// shutdown: Shutdown would give up on in-flight requests immediately and
	// report the cancellation as an error. Use a fresh, bounded context.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	if err := server.Shutdown(shutdownCtx); err != nil {
		// Graceful shutdown failed; drop whatever is still open. The shutdown
		// error is the one worth reporting, so the Close error is ignored.
		_ = server.Close()
		return err
	}
	return nil
}

// nextPowerOfTwo returns the smallest power of 2 that is greater than or
// equal to x. For x <= 1 (including zero and negative values) it returns 1.
//
// The result is computed with integer arithmetic only, so it is exact for
// every x. If no power of 2 that fits in an int is >= x, the largest power of
// 2 representable as an int is returned instead.
func nextPowerOfTwo(x int) int {
	// maxPowerOfTwo is the largest power of 2 that fits in an int; doubling
	// it would overflow.
	const maxPowerOfTwo = math.MaxInt/2 + 1

	p := 1
	for p < x && p < maxPowerOfTwo {
		p <<= 1
	}
	return p
}

// routingKey returns the key under which the routing info of the given
// colocation group is stored: routingInfoKey joined with the group name as a
// slash-separated path.
func routingKey(group string) string {
	return path.Join(routingInfoKey, group)
}
